import torch

# Create a (20, 28, 28) tensor of ones.
x = torch.ones((20, 28, 28))

# Insert a size-1 axis at dim 1: (20, 28, 28) -> (20, 1, 28, 28).
# unsqueeze(1) derives the result shape from x itself instead of repeating the
# literal sizes (a hard-coded reshape would raise if x's dimensions changed).
# NOTE(review): presumably this is an N x C x H x W layout with C=1 — confirm.
s = x.unsqueeze(1)

print(s.shape)




